{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b84e1e81",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "81223461a62f4228ae63bef028fd2277",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8be0ab862834478e950764925c8a6390",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "data/test-00000-of-00001.parquet:   0%|          | 0.00/415k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9f469fe93bcc4238ad180794f07d6131",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "data/train-00000-of-00001.parquet:   0%|          | 0.00/60.9M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from huggingface_hub import snapshot_download\n",
    "\n",
    "f = snapshot_download(\n",
    "    repo_id=\"mesolitica/CompA-R-Instructions\",\n",
    "    repo_type='dataset',\n",
    "    allow_patterns=\"data/*.parquet\",\n",
    "    local_dir=\"./CompA-R-Instructions\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "acda3802",
   "metadata": {},
   "outputs": [],
   "source": [
    "from glob import glob\n",
    "from tqdm import tqdm\n",
    "from multiprocess import Pool\n",
    "import itertools\n",
    "import zipfile\n",
    "import os\n",
    "\n",
    "def chunks(l, n):\n",
    "    for i in range(0, len(l), n):\n",
    "        yield (l[i: i + n], i // n)\n",
    "\n",
    "\n",
    "def multiprocessing(strings, function, cores=6, returned=True):\n",
    "    df_split = chunks(strings, len(strings) // cores)\n",
    "    pool = Pool(cores)\n",
    "    pooled = pool.map(function, df_split)\n",
    "    pool.close()\n",
    "    pool.join()\n",
    "\n",
    "    if returned:\n",
    "        return list(itertools.chain(*pooled))\n",
    "\n",
    "def loop(files):\n",
    "    files, _ = files\n",
    "    for zip_file_path in tqdm(files):\n",
    "        destination_folder = './'\n",
    "        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:\n",
    "            zip_ref.extractall(destination_folder)\n",
    "        os.remove(zip_file_path)\n",
    "\n",
    "# files = glob('*.zip')\n",
    "# if len(files):\n",
    "#     multiprocessing(files, loop, cores = min(len(files), 20), returned = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "356cfc51",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.4\n",
      "  warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoProcessor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "05a611b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "processor = AutoProcessor.from_pretrained(\"Qwen/Qwen2-Audio-7B-Instruct\")\n",
    "tokenizer = processor.tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3f73f0fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from glob import glob\n",
    "import pandas as pd\n",
    "\n",
    "rows = []\n",
    "\n",
    "files = glob('CompA-R-Instructions/data/train*.parquet')\n",
    "for f in files:\n",
    "    df = pd.read_parquet(f).to_dict(orient = 'records')\n",
    "    rows.extend(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "07a8ff8d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "198648"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1f0fe0c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pandas as pd\n",
    "\n",
    "def loop(rows):\n",
    "    rows, _ = rows\n",
    "    data = []\n",
    "    for r in tqdm(rows):\n",
    "        f = r['audio_filename']\n",
    "        if not os.path.exists(f):\n",
    "            continue\n",
    "            \n",
    "        try:\n",
    "            conversation = [\n",
    "                {\"role\": \"user\", \"content\": [\n",
    "                    {\"type\": \"audio\", \"audio_url\": \"audio.wav\"},\n",
    "                    {\"type\": \"text\", \"text\": r['question']},\n",
    "                ]},\n",
    "                {\"role\": \"assistant\", \"content\": r['answer']},\n",
    "            ]\n",
    "            text = processor.apply_chat_template(conversation, tokenize=False)\n",
    "        except Exception as e:\n",
    "            continue\n",
    "        \n",
    "\n",
    "        data.append({\n",
    "            'text': text,\n",
    "            'audio': f,\n",
    "        })\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b667c078",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 66%|█████████████████████████████████████████████████████▌                           | 4379/6621 [00:00<00:00, 5666.81it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5553.33it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5413.46it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5506.88it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5524.07it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5531.58it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5510.86it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5484.75it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5509.35it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5476.37it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5473.02it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5453.84it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5577.12it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5595.22it/s]\n",
      "  0%|                                                                                              | 0/6621 [00:00<?, ?it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5550.08it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5573.35it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5591.77it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5037.09it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5549.40it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5538.01it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 5184.91it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5586.99it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5553.25it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5591.93it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5572.22it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5282.15it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5585.10it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 4909.23it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5534.83it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████| 6621/6621 [00:01<00:00, 5534.26it/s]\n"
     ]
    }
   ],
   "source": [
    "processed = multiprocessing(rows, loop, cores = 30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b7e7d3c5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "198648"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(processed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e30cd7eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('prepare-CompA-R-Instructions.json', 'w') as fopen:\n",
    "    json.dump(processed, fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a5971c69",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n<|im_start|>user\\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\\nAnalyze the frequency and duration of the revving sounds in the audio. Based on these characteristics, infer the type of vehicle producing these sounds.<|im_end|>\\n<|im_start|>assistant\\nThe frequent and lengthy revving sounds suggest a powerful vehicle, likely a race car or motorcycle, which fits the context of a race car event.<|im_end|>\\n',\n",
       " 'audio': 'compa_r_train_audios-mp3/YBaw0jIZ0STo.mp3'}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "processed[0]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
